"""
This file is based on code from @codemirror/state and @codemirror/collab.

MIT License

Copyright (C) 2018-2021 by Marijn Haverbeke <marijn@haverbeke.berlin> and others

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import dataclasses
import itertools
from difflib import SequenceMatcher

from sysreptor.utils.utils import get_at


class CollabStr:
    """
    Wrapper class for strings that handles unicode characters like in JavaScript strings.
    Python uses UTF-32 code units internally, while JavaScript uses UTF-16 code units.
    This results in discrepancies for string length and indexing for unicode characters larger than 16 bits (e.g. emojis).
    See https://hsivonen.fi/string-length/
    """

    def __init__(self, py_str) -> None:
        if isinstance(py_str, CollabStr):
            self.str_bytes = py_str.str_bytes
        elif isinstance(py_str, bytes):
            self.str_bytes = py_str
        else:
            self.str_bytes = py_str \
                .encode('utf-16-le')

    def __str__(self) -> str:
        return self.str_bytes.decode('utf-16-le')

    def __repr__(self) -> str:
        return repr(str(self))

    def __len__(self):
        return len(self.str_bytes) // 2

    def __getitem__(self, key):
        if isinstance(key, slice):
            return CollabStr(self.str_bytes[slice(
                key.start * 2 if key.start is not None else None,
                key.stop * 2 if key.stop is not None else None,
                key.step * 2 if key.step is not None else None,
            )])
        elif isinstance(key, int):
            return CollabStr(self.str_bytes[key * 2:((key * 2) + 2) or None])
        else:
            raise TypeError('Invalid argument type')

    def __add__(self, other):
        if isinstance(other, CollabStr):
            return CollabStr(self.str_bytes + other.str_bytes)
        elif isinstance(other, str):
            return self + CollabStr(other)
        else:
            raise TypeError('Invalid argument type')

    def __eq__(self, value: object) -> bool:
        if isinstance(value, CollabStr):
            return self.str_bytes == value.str_bytes
        elif isinstance(value, str):
            return self.str_bytes == CollabStr(value).str_bytes
        else:
            return self.str_bytes == value

    def __hash__(self) -> int:
        return hash(self.str_bytes)

    def __iter__(self):
        return map(lambda b: CollabStr(bytes(b)), itertools.batched(self.str_bytes, 2))

    def join(self, iterable):
        return CollabStr(self.str_bytes.join(s.str_bytes for s in iterable))


@dataclasses.dataclass
class ChangeSet:
    sections: list[int]
    inserted: list[CollabStr]

    @property
    def empty(self):
        """
        False when there are actual changes in this set.
        """
        return len(self.sections) == 0 or (len(self.sections) == 2 and self.sections[1] < 0)

    @property
    def length(self):
        """
        The length of the document before the change.
        """
        result = 0
        for i in range(0, len(self.sections), 2):
            result += self.sections[i]
        return result

    @classmethod
    def from_dict(cls, changes: list):
        sections = []
        inserted = []
        for i, part in enumerate(changes):
            if isinstance(part, int):
                sections.extend([part, -1])
            elif not isinstance(part, list) or len(part) == 0 or not isinstance(part[0], int) or not all(map(lambda e: isinstance(e, str|CollabStr), part[1:])):
                raise ValueError('Invalid change')
            else:
                while len(inserted) <= i:
                    inserted.append(CollabStr(''))  # Text.empty
                inserted[i] = CollabStr('\n').join(map(CollabStr, part[1:]))
                sections.extend([part[0], len(inserted[i])])
        return ChangeSet(sections=sections, inserted=inserted)

    def to_dict(self):
        """
        Serialize this change set to a JSON-representable value.
        """
        parts = []
        for i in range(0, len(self.sections), 2):
            i_len = self.sections[i]
            ins = self.sections[i + 1]
            if ins < 0:
                parts.append(i_len)
            elif ins == 0:
                parts.append([i_len])
            else:
                parts.append([i_len] + str(self.inserted[i >> 1]).split('\n'))
        return parts

    def compose(self, other: 'ChangeSet'):
        """
        Compute the combined effect of applying another set of changes
        after this one. The length of the document after this set should
        match the length before `other`.
        """
        if self.empty:
            return other
        elif other.empty:
            return self
        else:
            return compose_sets(self, other)

    def map(self, other: 'ChangeSet', before=False):
        """
        Given another change set starting in the same document, maps this
        change set over the other, producing a new change set that can be
        applied to the document produced by applying `other`. When
        `before` is `true`, order changes as if `this` comes before
        `other`, otherwise (the default) treat `other` as coming first.

        Given two changes `A` and `B`, `A.compose(B.map(A))` and
        `B.compose(A.map(B, true))` will produce the same document. This
        provides a basic form of [operational
        transformation](https://en.wikipedia.org/wiki/Operational_transformation),
        and can be used for collaborative editing.
        """
        if other.empty:
            return self
        else:
            return map_set(self, other, before)

    def map_pos(self, pos: int, assoc = -1):
        posA = 0
        posB = 0
        for i in range(0, len(self.sections), 2):
            i_len = self.sections[i]
            ins = self.sections[i + 1]
            endA = posA + i_len
            if ins < 0:
                if endA > pos:
                    return posB + (pos - posA)
                posB += i_len
            else:
                if endA > pos or endA == pos and assoc < 0 and i_len == 0:
                    return posB if (pos == posA or assoc < 0) else posB + ins
                posB += ins
            posA = endA
        if pos > posA:
            raise ValueError(f'Position {pos} is out of range for ChangeSet of length {posA}')
        return posB

    def iter_changes(self, individual):
        pos_a = 0
        pos_b = 0
        i = 0
        while i < len(self.sections):
            i_len = self.sections[i]
            ins = self.sections[i + 1]
            i += 2
            if ins < 0:
                pos_a += i_len
                pos_b += i_len
            else:
                end_a = pos_a
                end_b = pos_b
                text = CollabStr('')
                while True:
                    end_a += i_len
                    end_b += ins
                    if ins and self.inserted is not None:
                        text += self.inserted[(i - 2) >> 1]
                    if individual or i == len(self.sections) or self.sections[i + 1] < 0:
                        break
                    i_len = self.sections[i]
                    ins = self.sections[i + 1]
                    i += 2
                yield (pos_a, end_a, pos_b, end_b, text)
                pos_a = end_a
                pos_b = end_b


    def apply(self, doc: str):
        """
        Apply the changes in this set to a document, returning the new
        document.
        """
        # Normalize line breaks to get consistent positions
        doc = CollabStr(doc.replace('\r\n', '\n'))

        if self.length != len(doc):
            raise ValueError('Applying change set to a document with the wrong length')
        for (from_a, to_a, from_b, _to_b, text) in self.iter_changes(False):
            doc = doc[:from_b] + text + doc[from_b + (to_a - from_a):]
        return str(doc)

    @classmethod
    def from_diff(cls, text_before, text_after):
        return ChangeSet.from_dict(list(diff_lines(text_before.replace('\r\n', '\n'), text_after.replace('\r\n', '\n'))))


@dataclasses.dataclass
class SelectionRange:
    anchor: int
    head: int

    @property
    def empty(self):
        return self.head == self.anchor

    @property
    def from_(self):
        return min(self.anchor, self.head)

    @property
    def to(self):
        return max(self.anchor, self.head)

    @classmethod
    def from_dict(cls, r: dict):
        if not isinstance(r, dict):
            raise ValueError('Invalid selection range')

        anchor = r.get('anchor', r.get('from', r.get('from_')))
        head = r.get('head', r.get('to'))

        if not isinstance(anchor, int) or not isinstance(head, int) or anchor < 0 or head < 0:
            raise ValueError('Invalid selection range')

        return SelectionRange(anchor=anchor, head=head)

    def to_dict(self):
        return {
            'anchor': self.anchor,
            'head': self.head,
        }

    def map(self, change: ChangeSet):
        if self.empty:
            anchor = head = change.map_pos(self.anchor, assoc=-1)
        else:
            assoc = 1 if self.anchor < self.head else -1
            anchor = change.map_pos(self.anchor, assoc=assoc)
            head = change.map_pos(self.head, assoc * -1)
        return SelectionRange(head=head, anchor=anchor)


@dataclasses.dataclass
class EditorSelection:
    ranges: list[SelectionRange]
    main: int = 0

    @classmethod
    def from_dict(cls, d: dict):
        if not d or not isinstance(d.get('ranges'), list) or not isinstance(d.get('main'), int) or not (0 <= d['main'] < len(d['ranges'])):
            raise ValueError('Invalid selection')
        return EditorSelection(ranges=[SelectionRange.from_dict(r) for r in d['ranges']], main=d['main'])

    def to_dict(self):
        return {
            'ranges': [r.to_dict() for r in self.ranges],
            'main': self.main,
        }

    def map(self, change: ChangeSet):
        if change.empty:
            return self
        return EditorSelection(
            ranges=[range.map(change) for range in self.ranges],
            main=self.main,
        )


@dataclasses.dataclass
class Update:
    client_id: str
    version: float
    changes: ChangeSet

    def to_dict(self):
        return {
            'changes': self.changes.to_dict(),
        }

    @classmethod
    def from_dict(cls, u: dict):
        if not isinstance(u, dict) or not isinstance(u.get('client_id'), str) or not isinstance(u.get('version'), int|float) or \
           not isinstance(u.get('changes'), list):
            raise ValueError('Invalid update')
        return Update(
            client_id=u['client_id'],
            version=float(u['version']),
            changes=ChangeSet.from_dict(u['changes']),
        )


class SectionIter:
    def __init__(self, set: ChangeSet):
        self.set = set
        self.i = 0
        self.len = 0
        self.off = 0
        self.ins = 0
        self.next()

    @property
    def done(self):
        return self.ins == -2

    @property
    def text(self):
        index = (self.i - 2) >> 1
        return CollabStr('') if index >= len(self.set.inserted) else self.set.inserted[index]

    @property
    def len2(self):
        return self.len if self.ins < 0 else self.ins

    def next(self):
        if self.i < len(self.set.sections):
            self.len = self.set.sections[self.i]
            self.ins = self.set.sections[self.i + 1]
            self.i += 2
        else:
            self.len = 0
            self.ins = -2
        self.off = 0

    def forward(self, i_len: int):
        if i_len == self.len:
            self.next()
        else:
            self.len -= i_len
            self.off += i_len

    def forward2(self, i_len: int):
        if self.ins == -1:
            self.forward(i_len)
        elif i_len == self.ins:
            self.next()
        else:
            self.ins -= i_len
            self.off += i_len

    def text_bit(self, i_len: int | None=None):
        index = (self.i - 2) >> 1
        if index >= len(self.set.inserted) and not i_len:
            return CollabStr('')
        elif i_len is not None:
            return self.set.inserted[index][self.off:self.off + i_len]
        else:
            return self.set.inserted[index][self.off:]


def add_section(sections: list[int], i_len: int, ins: int, force_join=False):
    if i_len == 0 and ins < 0:
        return

    last = len(sections) - 2
    if last >= 0 and ins < 0 and ins == sections[last + 1]:
        sections[last] += i_len
    elif i_len == 0 and get_at(sections, last) == 0:
        sections[last + 1] += ins
    elif force_join:
        sections[last] += i_len
        sections[last + 1] += ins
    else:
        sections.extend([i_len, ins])


def add_insert(values: list[CollabStr], sections: list[int], value: CollabStr):
    if len(value) == 0:
        return

    index = (len(sections) - 2) >> 1
    if index < len(values):
        values[-1] += value
    else:
        while len(values) < index:
            values.append('')
        values.append(value)


def map_set(setA: ChangeSet, setB: ChangeSet, before: bool):
    """
    Produce a copy of setA that applies to the document after setB
    has been applied (assuming both start at the same document).
    """
    sections = []
    insert = []

    # Iterate over both sets in parallel. inserted tracks, for changes
    # in A that have to be processed piece-by-piece, whether their
    # content has been inserted already, and refers to the section index.
    a = SectionIter(setA)
    b = SectionIter(setB)
    inserted = -1
    while True:
        if a.ins == -1 and b.ins == -1:
            # Move across ranges skipped by both sets.
            i_len = min(a.len, b.len)
            add_section(sections, i_len, -1)
            a.forward(i_len)
            b.forward(i_len)
        elif b.ins >= 0 and (a.ins < 0 or inserted == a.i or a.off == 0 and (b.len < a.len or b.len == a.len and not before)):
            # If there's a change in B that comes before the next change in
            # A (ordered by start pos, then len, then before flag),
            # skip that (and process any changes in A it covers).
            i_len = b.len
            add_section(sections, b.ins, -1)
            while i_len > 0:
                piece = min(a.len, i_len)
                if a.ins >= 0 and inserted < a.i and a.len <= piece:
                    add_section(sections, 0, a.ins)
                    if insert is not None:
                        add_insert(insert, sections, a.text)
                    inserted = a.i
                a.forward(piece)
                i_len -= piece
            b.next()
        elif a.ins >= 0:
            # Process the part of a change in A up to the start of the next
            # non-deletion change in B (if overlapping).
            i_len = 0
            left = a.len
            while left > 0:
                if b.ins == -1:
                    piece = min(left, b.len)
                    i_len += piece
                    left -= piece
                    b.forward(piece)
                elif b.ins == 0 and b.len < left:
                    left -= b.len
                    b.next()
                else:
                    break
            add_section(sections, i_len, a.ins if inserted < a.i else 0)
            if insert is not None and inserted < a.i:
                add_insert(insert, sections, a.text)
            inserted = a.i
            a.forward(a.len - left)
        elif a.done and b.done:
            return ChangeSet(sections, insert)
        else:
            raise ValueError('Mismatched change set lengths')


def compose_sets(setA: ChangeSet, setB: ChangeSet):
    sections = []
    insert = []

    a = SectionIter(setA)
    b = SectionIter(setB)
    open = False
    while True:
        if a.done and b.done:
            return ChangeSet(sections, insert)
        elif a.ins == 0:
            # Deletion in A
            add_section(sections, a.len, 0, open)
            a.next()
        elif b.len == 0 and not b.done:
            # Insertion in B
            add_section(sections, 0, b.ins, open)
            if insert is not None:
                add_insert(insert, sections, b.text)
            b.next()
        elif a.done or b.done:
            raise ValueError('Mismatched change set lengths')
        else:
            i_len = min(a.len2, b.len)
            section_len = len(sections)
            if a.ins == -1:
                ins_b = -1 if b.ins == -1 else \
                        0 if b.off else \
                        b.ins
                add_section(sections, i_len, ins_b, open)
                if insert is not None and ins_b:
                    add_insert(insert, sections, b.text)
            elif b.ins == -1:
                add_section(sections, 0 if a.off else a.len, i_len, open)
                if insert is not None:
                    add_insert(insert, sections, a.text_bit(i_len))
            else:
                add_section(sections, 0 if a.off else a.len, 0 if b.off else b.ins, open)
                if insert is not None and not b.off:
                    add_insert(insert, sections, b.text)

            open = (a.ins > i_len or (b.ins >= 0 and b.len > i_len)) and (open or len(sections) > section_len)
            a.forward2(i_len)
            b.forward(i_len)


def diff_lines(text_before: str, text_after: str):
    lines_before = text_before.splitlines(keepends=True)
    lines_after = text_after.splitlines(keepends=True)

    idx_before = 0
    for tag, alo, ahi, blo, bhi in SequenceMatcher(a=lines_before, b=lines_after).get_opcodes():
        # Use CollabStr to calculate indices and lengths to handle unicode characters correctly
        a_str = CollabStr(''.join(lines_before[alo:ahi]))
        b_str = CollabStr(''.join(lines_after[blo:bhi]))
        idx_after = idx_before + len(a_str)

        match tag:
            case 'equal':
                yield idx_after - idx_before
            case 'insert':
                yield [0, b_str]
            case 'delete':
                yield [idx_after - idx_before, '']
            case 'replace':
                yield from diff_characters(str(a_str), str(b_str))
        idx_before = idx_after


def diff_characters(text_before: str, text_after: str):
    idx_before = 0
    # Calculate diff using python strings to not split unicode characters
    for tag, alo, ahi, blo, bhi in SequenceMatcher(a=text_before, b=text_after).get_opcodes():
        # Use CollabStr to calculate indices and lengths to handle unicode characters correctly
        a_str = CollabStr(text_before[alo:ahi])
        b_str = CollabStr(text_after[blo:bhi])
        idx_after = idx_before + len(a_str)
        if tag == 'equal':
            yield idx_after - idx_before
        else:
            yield [idx_after - idx_before, b_str]


def rebase_updates(updates: list[Update], selection: EditorSelection | None, over: list[Update]) -> tuple[list[Update], EditorSelection | None]:
    """
    Rebase and deduplicate an array of client-submitted updates that
    came in with an out-of-date version number. `over` should hold the
    updates that were accepted since the given version (or at least
    their change descs and client IDs). Will return an array of
    updates that, firstly, has updates that were already accepted
    filtered out, and secondly, has been moved over the other changes
    so that they apply to the current document version.
    """
    if not updates or not over:
        return updates, selection

    changes = None
    skip = 0
    version = None
    for update in over:
        other = updates[skip] if skip < len(updates) else None
        if other and other.client_id == update.client_id:
            if changes:
                changes = changes.map(other.changes, True)
            skip += 1
        else:
            if changes:
                changes = changes.compose(update.changes)
            else:
                changes = update.changes
        if version is None or update.version > version:
            version = update.version

    if skip:
        updates = updates[skip:]

    if not changes:
        return updates, selection
    else:
        out = []
        for update in updates:
            updated_changes = update.changes.map(changes)
            changes = changes.map(update.changes, True)
            out.append(Update(
                client_id=update.client_id,
                version=version,
                changes=updated_changes,
            ))
        if selection:
            selection = selection.map(changes)
        return out, selection
